# TODO: once this no longer needs to run in Colab, move this code into its own module.
# Pre-processing the data (code: Lewis G)

import os
import h5py
import numpy as np
import re
from PIL import Image

# Pre-processing images
def load_jpgs(path, size=(224, 224), max_images=None):
    """
    Load all jpgs on a path into a numpy array, resizing to a given image size.

    Args:
        path: Directory to scan (non-recursively) for JPEG files.
        size: Target ``(width, height)`` handed to ``crop_center_or_reshape``.
        max_images: Optional cap on the number of images loaded. ``None``
            (the default) loads every readable JPEG. An earlier version had
            an ``i < 1500`` check, but it ran once before the loop and so
            never limited anything; pass ``max_images=1500`` to get a cap.

    Returns:
        Float array of shape ``(n, height, width, 3)`` with pixel values
        rescaled from [0, 255] to [-1, 1]. Images appear in ``os.listdir``
        order. When nothing is loaded the array has shape
        ``(0, height, width, 3)`` so it can still be concatenated with a
        non-empty batch.

    Raises:
        RuntimeError: Propagated from ``crop_center_or_reshape``.
    """
    target_w, target_h = size
    imgs = []
    for name in os.listdir(path):
        if max_images is not None and len(imgs) >= max_images:
            break

        full_path = os.path.join(path, name)
        # Test the extension of the file name only; matching a pattern
        # against the whole path would also accept e.g. ".../jpgs/notes.txt".
        extension = os.path.splitext(name)[1].lower()
        if extension not in ('.jpg', '.jpeg'):
            continue
        if not os.path.isfile(full_path) or os.path.getsize(full_path) == 0:
            continue

        try:
            # Image.open is lazy: a truncated file only fails once the pixel
            # data is read, so the decode has to sit inside the try as well.
            with Image.open(full_path) as image:
                pixels = list(image.getdata())
                mode, dims = image.mode, image.size
        except OSError:
            continue  # ignore corrupt files

        # Rebuild a fresh in-memory image from the raw pixel values so that
        # nothing but the pixels is carried over from the opened file.
        im = Image.new(mode, dims)
        im.putdata(pixels)
        if im.mode != 'RGB':
            im = im.convert('RGB')
        im = crop_center_or_reshape(im, size)
        # Map 8-bit channel values [0, 255] onto [-1, 1].
        imgs.append(2 * (np.asarray(im) / 255) - 1)

    if not imgs:
        return np.empty((0, target_h, target_w, 3), dtype=np.float64)
    return np.array(imgs)


def crop_center_or_reshape(im, size):
    """
    Shrink an image towards ``size`` and return its centre crop of that size.

    Args:
        im: A PIL-style image exposing ``thumbnail``, ``size`` and ``crop``.
            ``thumbnail`` is called on it directly, so the caller's object
            is resized as a side effect (Pillow's ``thumbnail`` works in
            place and never enlarges).
        size: Target ``(width, height)``; a tuple or a list.

    Returns:
        The cropped image, whose ``size`` equals ``(width, height)``.

    Raises:
        RuntimeError: If the crop does not come out at the requested size.
    """
    tw, th = size

    # heuristic: first shrink so the image fits within 1.5x the target in
    # BOTH dimensions, then take the centre crop. (This used to be
    # ``th + 1.5``, which capped the height at roughly the target and left
    # portrait images narrower than the crop box.)
    im.thumbnail((int(tw * 1.5), int(th * 1.5)))
    iw, ih = im.size

    # Split the surplus (or deficit, when negative) evenly around the centre;
    # -(-d // 2) is ceil(d / 2) and d // 2 is floor(d / 2) for integers.
    dx = iw - tw
    dy = ih - th
    left = -(-dx // 2)
    right = iw - dx // 2
    top = -(-dy // 2)
    bottom = ih - dy // 2

    # NOTE(review): when the image is smaller than the target the box extends
    # past its edges; Pillow is expected to pad that area -- confirm.
    cropped = im.crop((left, top, right, bottom))
    # Compare against a tuple so a list-valued ``size`` is not rejected.
    if tuple(cropped.size) != (tw, th):
        raise RuntimeError(
            'center crop produced size {} instead of {}'.format(
                tuple(cropped.size), (tw, th)))
    return cropped


# Lewis' function to fix naming problems
def gen_save_name(basename=None):
    """
    Generate a unique name for a saved file to avoid overwrites.

    Args:
        basename: Desired path. Defaults to the current working directory,
            looked up when the function is called rather than once at
            import time.

    Returns:
        ``basename`` itself if nothing exists at that path; otherwise the
        first of ``<stem>_1<ext>``, ``<stem>_2<ext>``, ... that does not
        exist, where ``<ext>`` is the final extension (possibly empty).
        Only the last dot of the file name is treated as the extension
        separator, so ``model.v2.h5`` and ``./out/x.h5`` are handled, and
        names without any extension are accepted too.
    """
    if basename is None:
        basename = os.getcwd()

    # splitext only looks at the last path component and its last dot,
    # unlike str.split('.') which breaks on every dot in the whole path.
    stem, ext = os.path.splitext(basename)

    candidate = basename
    qualifier = 1
    while os.path.exists(candidate):
        candidate = '{}_{}{}'.format(stem, qualifier, ext)
        qualifier += 1
    return candidate


def load_or_create_dataset(directory):
    """
    Return the cat/dog train/test split, building and caching it if needed.

    Args:
        directory: Path of the HDF5 cache *file* (despite the name). If it
            exists, the split is read from it. Otherwise the images under
            ``<cwd>/PetImages/Cat`` and ``<cwd>/PetImages/Dog`` are loaded
            with ``load_jpgs``, shuffled, split 80/20 and written to this
            path.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)``. When built here, ``X`` is
        float32 image data and ``Y`` holds 0.0 for cats and 1.0 for dogs.
        The shuffle uses numpy's global random state, so it is only
        reproducible if the caller seeds it.
    """
    if os.path.exists(directory):
        with h5py.File(directory, 'r') as f:
            # ``Dataset.value`` was removed in h5py 3.0; indexing with an
            # empty tuple reads the whole dataset into a numpy array.
            X_train = f['train']['X'][()]
            Y_train = f['train']['Y'][()]

            X_test = f['test']['X'][()]
            Y_test = f['test']['Y'][()]
        return X_train, Y_train, X_test, Y_test

    image_root = os.path.join(os.getcwd(), 'PetImages')

    # Down-cast each class to float32 straight away so the full float64
    # copies never have to coexist with the concatenated array.
    cats = load_jpgs(path=os.path.join(image_root, 'Cat'),
                     size=(224, 224)).astype(np.float32)
    dogs = load_jpgs(path=os.path.join(image_root, 'Dog'),
                     size=(224, 224)).astype(np.float32)

    # Labels: 0 = cat, 1 = dog.
    Y = np.concatenate([np.zeros(cats.shape[0]), np.ones(dogs.shape[0])])
    X = np.concatenate([cats, dogs])
    del cats, dogs

    # shuffle data (same permutation for images and labels)
    inds = np.random.permutation(X.shape[0])
    X = X[inds]
    Y = Y[inds]

    # 80% train / 20% test
    split = int(0.8 * X.shape[0])
    X_train, X_test = X[:split], X[split:]
    Y_train, Y_test = Y[:split], Y[split:]

    # Cache the split in an HDF5 file so later runs can skip the image
    # decoding above.
    with h5py.File(directory, 'w') as f:
        tr = f.create_group('train')
        te = f.create_group('test')
        tr.create_dataset('X', data=X_train)
        tr.create_dataset('Y', data=Y_train)

        te.create_dataset('X', data=X_test)
        te.create_dataset('Y', data=Y_test)
    return X_train, Y_train, X_test, Y_test